import os
import cv2
import numpy as np
import torch
from face_detection import FaceAlignment, LandmarksType
from mmpose.apis import inference_topdown, init_model
from mmpose.structures import merge_data_samples

# Pick the compute device once at import time; both models below use it.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_str)

# Face detector from the face_detection package. It receives the device as a
# plain string rather than the torch.device object.
fa = FaceAlignment(LandmarksType._2D, flip_input=False, device=device_str)

# mmpose whole-body model, used here to obtain facial landmarks.
# The config/checkpoint locations can be overridden with environment variables
# so the module works outside the original deployment layout; the defaults are
# the original hard-coded paths.
config_file = os.environ.get(
    "AVATARSYNC_POSE_CONFIG",
    '/root/AvatarSync/face_cropping/utils/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py',
)
checkpoint_file = os.environ.get(
    "AVATARSYNC_POSE_CHECKPOINT",
    '/root/AvatarSync/face_cropping/models/dw-ll_ucoco_384.pth',
)
pose_model = init_model(config_file, checkpoint_file, device=device)

def _landmark_face_bbox(face_landmarks, upperbondrange=0):
    """Estimate an (x1, y1, x2, y2) face box from facial landmarks.

    Horizontally the box spans the landmarks. Vertically it runs from the
    lowest landmark up to the point mirrored about landmark 29. Landmark 29's
    y is first shifted by ``upperbondrange`` pixels.

    Args:
        face_landmarks: Integer array whose rows are (x, y) pixel coordinates;
            needs at least 30 rows (row 29 is used as the mirror point).
        upperbondrange: Pixel offset added to landmark 29's y before mirroring.

    Returns:
        Tuple of Python ints (x1, y1, x2, y2). The box is NOT validated or
        clipped: y1 may be negative and the box may be empty.
    """
    # Copy: row indexing returns a view, and the shift below must not silently
    # modify face_landmarks itself.
    half_face_coord = face_landmarks[29].copy()
    if upperbondrange != 0:
        half_face_coord[1] = upperbondrange + half_face_coord[1]

    lowest_y = np.max(face_landmarks[:, 1])
    half_face_dist = lowest_y - half_face_coord[1]
    upper_bound = half_face_coord[1] - half_face_dist

    x1 = int(np.min(face_landmarks[:, 0]))
    y1 = int(upper_bound)
    x2 = int(np.max(face_landmarks[:, 0]))
    y2 = int(lowest_y)
    return x1, y1, x2, y2


def _clamp_bbox(box, width, height):
    """Clip an (x1, y1, x2, y2) box to the image; return None if nothing is left.

    Clipping matters because a negative start index in a numpy slice counts
    from the end of the axis instead of meaning "before the edge".
    """
    x1, y1, x2, y2 = (int(v) for v in box)
    x1 = min(max(x1, 0), width)
    x2 = min(max(x2, 0), width)
    y1 = min(max(y1, 0), height)
    y2 = min(max(y2, 0), height)
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


def process_face_image(image_path, target_size=320, margin_ratio=0.3, upperbondrange=0):
    """
    Crop the face region from an image and resize it to target_size x target_size.

    Box selection:
      - If the pose model yields keypoints and the face detector finds a face,
        the detector's box is used.
      - If the detector finds nothing, a box is estimated from the landmarks.
      - Otherwise the whole image is used.

    The chosen box is clipped to the image bounds before cropping. A box that
    is empty after clipping also falls back to the whole image.

    Args:
        image_path: Path to the input image (read with cv2.imread).
        target_size: Side length of the square output crop (default: 320).
        margin_ratio: Not used by this function; kept for API compatibility
            (default: 0.3).
        upperbondrange: Pixel offset applied to landmark 29's y when
            estimating the upper edge of the landmark-based box (default: 0).

    Returns:
        face_tensor: Float tensor of shape (1, 3, target_size, target_size),
            RGB channel order, scaled to [-1, 1].
        face_metadata: dict with "face_location" as (top, right, bottom, left),
            "original_size" as (height, width), "use_whole_image",
            "crop_size" as (height, width) of the crop before resizing,
            "resized_shape" and "target_size".
        original_image_rgb: The full input image converted to RGB.

    Raises:
        ValueError: If the image cannot be loaded.
    """
    original_image = cv2.imread(image_path)
    if original_image is None:
        raise ValueError(f"Cannot load image: {image_path}")

    original_image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
    original_height, original_width = original_image_rgb.shape[:2]

    # Both the pose model and the face detector are given the BGR array as
    # returned by cv2.imread.
    results = merge_data_samples(inference_topdown(pose_model, original_image))

    face_box = None  # (x1, y1, x2, y2); None means "use the whole image"
    if hasattr(results.pred_instances, 'keypoints') and len(results.pred_instances.keypoints) > 0:
        keypoints = results.pred_instances.keypoints
        # Slice 23:91 selects 68 points. This assumes the COCO-WholeBody
        # 133-keypoint layout of the wholebody config, where these are the face
        # points. TODO confirm if a different pose model is configured.
        face_landmarks = keypoints[0][23:91].astype(np.int32)

        # The detector returns one entry per image: None, or (x1, y1, x2, y2).
        detections = fa.get_detections_for_batch(np.asarray([original_image]))

        if detections[0] is None:
            x1, y1, x2, y2 = _landmark_face_bbox(face_landmarks, upperbondrange)
            # Original rejection rule: empty box, or box starting left of the image.
            if not (y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0):
                face_box = (x1, y1, x2, y2)
        else:
            face_box = tuple(int(v) for v in detections[0])

    # Clip to the image. Previously a negative y1 wrapped the slice around,
    # giving a wrong or empty crop (and a cv2.resize crash when empty).
    if face_box is not None:
        face_box = _clamp_bbox(face_box, original_width, original_height)

    use_whole_image = face_box is None
    if use_whole_image:
        x1, y1, x2, y2 = 0, 0, original_width, original_height
    else:
        x1, y1, x2, y2 = face_box

    face_image = original_image_rgb[y1:y2, x1:x2]
    face_height, face_width = face_image.shape[:2]

    resized_face = cv2.resize(face_image, (target_size, target_size))
    resized_height, resized_width = resized_face.shape[:2]

    # HWC uint8 -> (1, C, H, W) float in [-1, 1].
    face_tensor = torch.from_numpy(resized_face).permute(2, 0, 1).float() / 127.5 - 1.0
    face_tensor = face_tensor.unsqueeze(0)

    face_metadata = {
        "face_location": (y1, x2, y2, x1),  # (top, right, bottom, left)
        "original_size": (original_height, original_width),
        "use_whole_image": use_whole_image,
        "crop_size": (face_height, face_width),
        "resized_shape": (resized_height, resized_width),
        "target_size": target_size
    }

    return face_tensor, face_metadata, original_image_rgb

def _face_tensor_to_uint8(face_tensor):
    """Convert face_tensor[0], a (C, H, W) image in [-1, 1], to an (H, W, C) uint8 array.

    - Detaches and moves to CPU first, so CUDA tensors and tensors that require
      grad work.
    - Values are rounded (not truncated), so pixels survive a uint8 -> [-1, 1]
      -> uint8 round trip.
    - Values are clipped to [0, 255], so out-of-range generator outputs saturate
      instead of wrapping around.
    """
    face_np = face_tensor[0].detach().permute(1, 2, 0).cpu().numpy()
    return np.clip(np.rint((face_np + 1.0) * 127.5), 0, 255).astype(np.uint8)


def merge_face_back(generated_face, face_metadata, original_image, target_size=None):
    """
    Merge the generated face back into the original image.

    If the whole image was used as the "face", the generated face simply
    replaces it. Otherwise the generated face is resized to the region
    described by face_metadata["face_location"] and pasted there.

    Args:
        generated_face: Tensor whose first item (generated_face[0]) is a
            (C, h, w) image scaled to [-1, 1]; may live on any device.
        face_metadata: dict as produced by process_face_image. Reads
            "use_whole_image" and "face_location" (top, right, bottom, left).
        original_image: Image the face was cropped from. Presumably the RGB
            uint8 array returned by process_face_image.
        target_size: Optional (width, height) for the final image.

    Returns:
        merged_tensor: Float tensor of shape (1, C, H, W) scaled to [-1, 1].

    Raises:
        ValueError: If "face_location" selects an empty region of original_image.
    """
    face_np = _face_tensor_to_uint8(generated_face)

    if face_metadata.get("use_whole_image", False):
        # The whole image was the "face": the generated face becomes the output.
        if target_size:
            out_size = (target_size[0], target_size[1])
        else:
            out_size = (original_image.shape[1], original_image.shape[0])
        final_output = cv2.resize(face_np, out_size)
    else:
        top, right, bottom, left = face_metadata["face_location"]
        output_image = original_image.copy()

        # Size the face to the exact destination slice (same slicing as the
        # crop), so the assignment below can never hit a shape mismatch even
        # if "crop_size" is missing or stale.
        region_height, region_width = output_image[top:bottom, left:right].shape[:2]
        if region_height == 0 or region_width == 0:
            raise ValueError(
                f"face_location {face_metadata['face_location']} selects an empty region"
            )
        output_image[top:bottom, left:right] = cv2.resize(face_np, (region_width, region_height))

        if target_size:
            final_output = cv2.resize(output_image, (target_size[0], target_size[1]))
        else:
            final_output = output_image

    # HWC uint8 -> (1, C, H, W) float in [-1, 1].
    merged_tensor = torch.from_numpy(final_output).permute(2, 0, 1).float() / 127.5 - 1.0
    return merged_tensor.unsqueeze(0)

# Plain image loader without face cropping (fallback path; also handy for testing)
def preprocess_image(image_path, target_size=None):
    """
    Load an image as a normalized RGB tensor, skipping face detection entirely.

    Intended as the fallback when face cropping is not wanted.

    Args:
        image_path: Path to the image (read with cv2.imread).
        target_size: If truthy, the image is resized to target_size x target_size.

    Returns:
        Float tensor of shape (1, 3, H, W), RGB channel order, scaled to [-1, 1].

    Raises:
        ValueError: If the image cannot be loaded.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Cannot load image: {image_path}")

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if target_size:
        rgb = cv2.resize(rgb, (target_size, target_size))

    chw = torch.from_numpy(rgb).permute(2, 0, 1)
    return (chw.float() / 127.5 - 1.0).unsqueeze(0)

# For testing
if __name__ == "__main__":
    def _tensor_to_bgr(image_tensor):
        """Turn image_tensor[0], a (C, H, W) RGB image in [-1, 1], into a BGR uint8 array for cv2.imwrite.

        Rounds instead of truncating, so the saved pixels match the source
        image rather than being one level darker. Clips to [0, 255].
        """
        rgb = image_tensor[0].detach().permute(1, 2, 0).cpu().numpy()
        rgb = np.clip(np.rint((rgb + 1.0) * 127.5), 0, 255).astype(np.uint8)
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    test_image_path = "test.png"
    face_tensor, face_metadata, original_image = process_face_image(test_image_path, target_size=320)
    print("Face tensor shape:", face_tensor.shape)
    print("Face metadata:", face_metadata)

    # Save the cropped face
    cv2.imwrite("cropped_face.png", _tensor_to_bgr(face_tensor))

    # Test merging back
    merged_tensor = merge_face_back(face_tensor, face_metadata, original_image)
    cv2.imwrite("merged_face.png", _tensor_to_bgr(merged_tensor))

    print("Test completed. Check 'cropped_face.png' and 'merged_face.png'")